# Libraries
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.4     ✔ readr     2.1.5
## ✔ forcats   1.0.0     ✔ stringr   1.5.2
## ✔ ggplot2   4.0.0     ✔ tibble    3.3.0
## ✔ lubridate 1.9.4     ✔ tidyr     1.3.1
## ✔ purrr     1.1.0     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(mgcv)
## Loading required package: nlme
## 
## Attaching package: 'nlme'
## 
## The following object is masked from 'package:dplyr':
## 
##     collapse
## 
## This is mgcv 1.9-3. For overview type 'help("mgcv-package")'.
# Import the data subsets
# RHP/RHH subset, read from the working directory
RHP_RHH_df <- read_csv(file = "RHP_RHH_bip.csv")
## Rows: 144820 Columns: 42
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (15): gameid, pitcher, pitcherthrows, batter, batterside, pitchresult, ...
## dbl  (26): ab, pitchnum, inning, teambat, balls, strikes, outs, visscore, ho...
## time  (1): GameDate
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# RHP/LHH subset, read from the working directory
RHP_LHH_df <- read_csv(file = "RHP_LHH_bip.csv")
## Rows: 123442 Columns: 42
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (15): gameid, pitcher, pitcherthrows, batter, batterside, pitchresult, ...
## dbl  (26): ab, pitchnum, inning, teambat, balls, strikes, outs, visscore, ho...
## time  (1): GameDate
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# LHP/RHH subset, read from the working directory
LHP_RHH_df <- read_csv(file = "LHP_RHH_bip.csv")
## Rows: 72971 Columns: 42
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (15): gameid, pitcher, pitcherthrows, batter, batterside, pitchresult, ...
## dbl  (26): ab, pitchnum, inning, teambat, balls, strikes, outs, visscore, ho...
## time  (1): GameDate
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# LHP/LHH subset, read from the working directory
LHP_LHH_df <- read_csv(file = "LHP_LHH_bip.csv")
## Rows: 26371 Columns: 42
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (15): gameid, pitcher, pitcherthrows, batter, batterside, pitchresult, ...
## dbl  (26): ab, pitchnum, inning, teambat, balls, strikes, outs, visscore, ho...
## time  (1): GameDate
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Binomial GAM for GIDP_SweetSpot fitted on the RHP_RHH_df subset.
# Pitch traits enter as univariate smooths; plate location and initial
# position each get main-effect smooths plus a ti() interaction term.
# Smoothing parameters are selected by REML.
GAM_RR <- gam(
  formula = GIDP_SweetSpot ~
    s(spinrate) + s(relspeed) + s(inducedvertbreak) + s(horzbreak) +
    ti(platelocside, platelocheight) + s(platelocside) + s(platelocheight) +
    ti(initposx, initposz) + s(initposx) + s(initposz),
  family = binomial,
  data = RHP_RHH_df,
  method = "REML"
)

# Term-level significance and overall fit statistics
summary(GAM_RR)
## 
## Family: binomial 
## Link function: logit 
## 
## Formula:
## GIDP_SweetSpot ~ s(spinrate) + s(relspeed) + s(inducedvertbreak) + 
##     s(horzbreak) + ti(platelocside, platelocheight) + s(platelocside) + 
##     s(platelocheight) + ti(initposx, initposz) + s(initposx) + 
##     s(initposz)
## 
## Parametric coefficients:
##              Estimate Std. Error z value Pr(>|z|)    
## (Intercept) -1.655941   0.008127  -203.8   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Approximate significance of smooth terms:
##                                   edf Ref.df   Chi.sq p-value    
## s(spinrate)                     3.568  4.577   15.083 0.00728 ** 
## s(relspeed)                     4.256  5.362  245.481 < 2e-16 ***
## s(inducedvertbreak)             6.208  7.404  798.282 < 2e-16 ***
## s(horzbreak)                    4.645  5.784  195.252 < 2e-16 ***
## ti(platelocside,platelocheight) 8.861 10.144  161.537 < 2e-16 ***
## s(platelocside)                 7.654  8.543  385.655 < 2e-16 ***
## s(platelocheight)               3.436  4.373 1002.230 < 2e-16 ***
## ti(initposx,initposz)           2.349  2.862   10.512 0.01578 *  
## s(initposx)                     1.159  1.302    1.152 0.42889    
## s(initposz)                     1.116  1.211   32.486 < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## R-sq.(adj) =  0.0274   Deviance explained = 3.15%
## -REML =  64042  Scale est. = 1         n = 144820
# Binomial GAM for GIDP_SweetSpot fitted on the RHP_LHH_df subset.
# Pitch traits enter as univariate smooths; plate location and initial
# position each get main-effect smooths plus a ti() interaction term.
# Smoothing parameters are selected by REML.
GAM_RL <- gam(
  formula = GIDP_SweetSpot ~
    s(spinrate) + s(relspeed) + s(inducedvertbreak) + s(horzbreak) +
    ti(platelocside, platelocheight) + s(platelocside) + s(platelocheight) +
    ti(initposx, initposz) + s(initposx) + s(initposz),
  family = binomial,
  data = RHP_LHH_df,
  method = "REML"
)

# Term-level significance and overall fit statistics
summary(GAM_RL)
## 
## Family: binomial 
## Link function: logit 
## 
## Formula:
## GIDP_SweetSpot ~ s(spinrate) + s(relspeed) + s(inducedvertbreak) + 
##     s(horzbreak) + ti(platelocside, platelocheight) + s(platelocside) + 
##     s(platelocheight) + ti(initposx, initposz) + s(initposx) + 
##     s(initposz)
## 
## Parametric coefficients:
##              Estimate Std. Error z value Pr(>|z|)    
## (Intercept) -1.795205   0.009352    -192   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Approximate significance of smooth terms:
##                                   edf Ref.df  Chi.sq  p-value    
## s(spinrate)                     5.244  6.489  60.653  < 2e-16 ***
## s(relspeed)                     5.667  6.897 540.268  < 2e-16 ***
## s(inducedvertbreak)             5.626  6.831 589.107  < 2e-16 ***
## s(horzbreak)                    7.674  8.575 152.750  < 2e-16 ***
## ti(platelocside,platelocheight) 6.896  8.348  60.948  < 2e-16 ***
## s(platelocside)                 6.087  7.290 880.565  < 2e-16 ***
## s(platelocheight)               3.751  4.737 482.716  < 2e-16 ***
## ti(initposx,initposz)           5.219  6.476  29.635 5.62e-05 ***
## s(initposx)                     1.904  2.416  40.836  < 2e-16 ***
## s(initposz)                     1.837  2.330   7.551   0.0469 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## R-sq.(adj) =  0.0256   Deviance explained = 3.09%
## -REML =  50753  Scale est. = 1         n = 123442
# Binomial GAM for GIDP_SweetSpot fitted on the LHP_RHH_df subset.
# Pitch traits enter as univariate smooths; plate location and initial
# position each get main-effect smooths plus a ti() interaction term.
# Smoothing parameters are selected by REML.
GAM_LR <- gam(
  formula = GIDP_SweetSpot ~
    s(spinrate) + s(relspeed) + s(inducedvertbreak) + s(horzbreak) +
    ti(platelocside, platelocheight) + s(platelocside) + s(platelocheight) +
    ti(initposx, initposz) + s(initposx) + s(initposz),
  family = binomial,
  data = LHP_RHH_df,
  method = "REML"
)

# Term-level significance and overall fit statistics
summary(GAM_LR)
## 
## Family: binomial 
## Link function: logit 
## 
## Formula:
## GIDP_SweetSpot ~ s(spinrate) + s(relspeed) + s(inducedvertbreak) + 
##     s(horzbreak) + ti(platelocside, platelocheight) + s(platelocside) + 
##     s(platelocheight) + ti(initposx, initposz) + s(initposx) + 
##     s(initposz)
## 
## Parametric coefficients:
##             Estimate Std. Error z value Pr(>|z|)    
## (Intercept) -1.74453    0.01163    -150   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Approximate significance of smooth terms:
##                                   edf Ref.df  Chi.sq p-value    
## s(spinrate)                     3.927  4.990  47.055 < 2e-16 ***
## s(relspeed)                     4.083  5.143 199.055 < 2e-16 ***
## s(inducedvertbreak)             5.274  6.384 312.427 < 2e-16 ***
## s(horzbreak)                    7.191  8.299  31.473 0.00014 ***
## ti(platelocside,platelocheight) 7.417  8.789  55.343 < 2e-16 ***
## s(platelocside)                 6.552  7.715 274.461 < 2e-16 ***
## s(platelocheight)               2.931  3.732 314.141 < 2e-16 ***
## ti(initposx,initposz)           1.746  2.107   1.269 0.50887    
## s(initposx)                     1.003  1.007   7.049 0.00807 ** 
## s(initposz)                     1.851  2.346   4.213 0.12602    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## R-sq.(adj) =  0.0226   Deviance explained =  2.7%
## -REML =  30984  Scale est. = 1         n = 72971
# Binomial GAM for GIDP_SweetSpot fitted on the LHP_LHH_df subset.
# Pitch traits enter as univariate smooths; plate location and initial
# position each get main-effect smooths plus a ti() interaction term.
# Smoothing parameters are selected by REML.
GAM_LL <- gam(
  formula = GIDP_SweetSpot ~
    s(spinrate) + s(relspeed) + s(inducedvertbreak) + s(horzbreak) +
    ti(platelocside, platelocheight) + s(platelocside) + s(platelocheight) +
    ti(initposx, initposz) + s(initposx) + s(initposz),
  family = binomial,
  data = LHP_LHH_df,
  method = "REML"
)

# Term-level significance and overall fit statistics
summary(GAM_LL)
## 
## Family: binomial 
## Link function: logit 
## 
## Formula:
## GIDP_SweetSpot ~ s(spinrate) + s(relspeed) + s(inducedvertbreak) + 
##     s(horzbreak) + ti(platelocside, platelocheight) + s(platelocside) + 
##     s(platelocheight) + ti(initposx, initposz) + s(initposx) + 
##     s(initposz)
## 
## Parametric coefficients:
##             Estimate Std. Error z value Pr(>|z|)    
## (Intercept)  -1.6759     0.0193  -86.83   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Approximate significance of smooth terms:
##                                   edf Ref.df  Chi.sq  p-value    
## s(spinrate)                     1.006  1.012   1.236    0.268    
## s(relspeed)                     3.174  4.070  33.464 6.82e-07 ***
## s(inducedvertbreak)             5.494  6.678 161.309  < 2e-16 ***
## s(horzbreak)                    5.452  6.660  69.627  < 2e-16 ***
## ti(platelocside,platelocheight) 5.045  6.260  37.916  < 2e-16 ***
## s(platelocside)                 2.823  3.602  83.465  < 2e-16 ***
## s(platelocheight)               1.002  1.004 106.452  < 2e-16 ***
## ti(initposx,initposz)           1.005  1.010   0.074    0.797    
## s(initposx)                     1.002  1.004   1.322    0.251    
## s(initposz)                     1.002  1.003  19.029 1.27e-05 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## R-sq.(adj) =  0.024   Deviance explained = 2.95%
## -REML =  11464  Scale est. = 1         n = 26371
# Smoothing parameters stored on the fitted GAM_RR object
GAM_RR[["sp"]]
##                      s(spinrate)                      s(relspeed) 
##                     7.758323e+00                     2.268960e+00 
##              s(inducedvertbreak)                     s(horzbreak) 
##                     1.036654e+00                     3.942775e+00 
## ti(platelocside,platelocheight)1 ti(platelocside,platelocheight)2 
##                     1.300854e-01                     6.901210e+00 
##                  s(platelocside)                s(platelocheight) 
##                     2.422814e-01                     1.213659e+01 
##           ti(initposx,initposz)1           ti(initposx,initposz)2 
##                     1.778650e+04                     4.819995e+01 
##                      s(initposx)                      s(initposz) 
##                     6.083816e+02                     2.060975e+02
# Smoothing parameters stored on the fitted GAM_RL object
GAM_RL[["sp"]]
##                      s(spinrate)                      s(relspeed) 
##                        1.4862810                        0.8882628 
##              s(inducedvertbreak)                     s(horzbreak) 
##                        1.2006682                        0.1771016 
## ti(platelocside,platelocheight)1 ti(platelocside,platelocheight)2 
##                        3.1108314                        4.4345738 
##                  s(platelocside)                s(platelocheight) 
##                        1.0184365                        7.0181416 
##           ti(initposx,initposz)1           ti(initposx,initposz)2 
##                        6.9994421                        7.7315587 
##                      s(initposx)                      s(initposz) 
##                       35.3366361                        9.9384377
# Smoothing parameters stored on the fitted GAM_LR object
GAM_LR[["sp"]]
##                      s(spinrate)                      s(relspeed) 
##                     3.543508e+00                     1.390835e+00 
##              s(inducedvertbreak)                     s(horzbreak) 
##                     1.168677e+00                     8.071295e-02 
## ti(platelocside,platelocheight)1 ti(platelocside,platelocheight)2 
##                     3.673251e-02                     9.945469e+00 
##                  s(platelocside)                s(platelocheight) 
##                     3.034135e-01                     1.432097e+01 
##           ti(initposx,initposz)1           ti(initposx,initposz)2 
##                     2.426045e+05                     6.579429e+01 
##                      s(initposx)                      s(initposz) 
##                     1.810714e+04                     2.282283e+01
# Smoothing parameters stored on the fitted GAM_LL object
GAM_LL[["sp"]]
##                      s(spinrate)                      s(relspeed) 
##                     2.322189e+03                     2.485375e+00 
##              s(inducedvertbreak)                     s(horzbreak) 
##                     1.289867e-01                     2.700248e-01 
## ti(platelocside,platelocheight)1 ti(platelocside,platelocheight)2 
##                     3.448309e-02                     1.657037e+01 
##                  s(platelocside)                s(platelocheight) 
##                     3.630099e+00                     7.839531e+03 
##           ti(initposx,initposz)1           ti(initposx,initposz)2 
##                     4.186264e+03                     1.733346e+05 
##                      s(initposx)                      s(initposz) 
##                     1.581901e+04                     8.643324e+03
# Default plot method for the fitted GAM_RR object
plot(x = GAM_RR)

# Visualizations: plate-location heatmaps

library(ggplot2)
library(RColorBrewer)
library(rlang)
## 
## Attaching package: 'rlang'
## The following objects are masked from 'package:purrr':
## 
##     %@%, flatten, flatten_chr, flatten_dbl, flatten_int, flatten_lgl,
##     flatten_raw, invoke, splice
# Heatmap of model predictions over a grid of two predictors.
#
# Arguments:
#   model  fitted model passed to predict()
#   xvar   name (string) of the column mapped to the x axis
#   zvar   name (string) of the column mapped to the y axis
#   data   data frame supplying the observed ranges of xvar/zvar and the
#          values at which every other column is held fixed
#   n      number of grid points per axis (n * n grid rows in total)
#
# Returns a ggplot object: filled tiles with black contour lines on top.
plot_gam_heatmap <- function(model, xvar, zvar, data, n = 200) {
  # Evenly spaced values spanning the observed range of one column
  axis_values <- function(col) {
    limits <- range(data[[col]], na.rm = TRUE)
    seq(limits[1], limits[2], length.out = n)
  }

  pred_grid <- expand.grid(x = axis_values(xvar), z = axis_values(zvar))
  names(pred_grid) <- c(xvar, zvar)

  # Every remaining column is held constant: numeric columns at their mean,
  # anything else at its first distinct value
  for (col in setdiff(names(data), c(xvar, zvar))) {
    observed <- data[[col]]
    pred_grid[[col]] <- if (is.numeric(observed)) {
      mean(observed, na.rm = TRUE)
    } else {
      unique(observed)[1]
    }
  }

  # Predictions on the response scale
  pred_grid$fit <- predict(model, newdata = pred_grid, type = "response")

  fill_colours <- rev(brewer.pal(11, "RdYlBu"))
  plot_title <- paste("Predicted Sweet-Spot Probability by", xvar, "and", zvar)

  heatmap <- ggplot(pred_grid, aes(x = !!sym(xvar), y = !!sym(zvar)))
  heatmap <- heatmap +
    geom_tile(aes(fill = fit)) +
    geom_contour(aes(z = fit), color = "black", linewidth = 0.3)
  heatmap <- heatmap +
    scale_fill_gradientn(colors = fill_colours) +
    coord_equal()
  heatmap <- heatmap +
    labs(x = xvar, y = zvar, fill = "Predicted Probability", title = plot_title) +
    theme_minimal(base_size = 14)

  heatmap
}



# Plate-location heatmaps, one per fitted model and its matching data subset
plot_gam_heatmap(
  model = GAM_LL, xvar = "platelocside", zvar = "platelocheight",
  data = LHP_LHH_df
)

plot_gam_heatmap(
  model = GAM_RR, xvar = "platelocside", zvar = "platelocheight",
  data = RHP_RHH_df
)

plot_gam_heatmap(
  model = GAM_RL, xvar = "platelocside", zvar = "platelocheight",
  data = RHP_LHH_df
)

plot_gam_heatmap(
  model = GAM_LR, xvar = "platelocside", zvar = "platelocheight",
  data = LHP_RHH_df
)

# Visualizations: heatmaps restricted to the strike zone

# Heatmap of model predictions over a rectangular plate-location window.
#
# Arguments:
#   model         fitted model passed to predict()
#   data          data frame supplying the values at which every column
#                 other than xvar/zvar is held fixed
#   xvar, zvar    names (strings) of the horizontal / vertical location columns
#   n             number of grid points per axis
#   side_range    c(min, max) of the grid along xvar
#   height_range  c(min, max) of the grid along zvar
#
# Returns a ggplot object: filled tiles, black contour lines, and a white
# outline around the side_range x height_range window.
plot_gam_heatmap_kzone <- function(model, data, 
                                xvar = "platelocside", 
                                zvar = "platelocheight",
                                n = 200,
                                side_range = c(-0.83, 0.83),
                                height_range = c(1.5, 3.5)) {
  # Grid restricted to the requested window
  grid <- expand.grid(
    x = seq(side_range[1], side_range[2], length.out = n),
    z = seq(height_range[1], height_range[2], length.out = n)
  )
  names(grid) <- c(xvar, zvar)

  # Hold every other column fixed: numeric columns at their mean, anything
  # else at its first distinct value
  other_vars <- setdiff(names(data), c(xvar, zvar))
  for (v in other_vars) {
    if (is.numeric(data[[v]])) {
      grid[[v]] <- mean(data[[v]], na.rm = TRUE)
    } else {
      grid[[v]] <- unique(data[[v]])[1]
    }
  }

  # Predictions on the response scale
  grid$fit <- predict(model, newdata = grid, type = "response")

  ggplot(grid, aes(x = !!sym(xvar), y = !!sym(zvar))) +
    geom_tile(aes(fill = fit)) +
    geom_contour(aes(z = fit), color = "black", linewidth = 0.3) +
    scale_fill_gradientn(colors = rev(brewer.pal(11, "RdYlBu"))) +
    # Limit the view through the coordinate system, not the y scale: scale
    # limits censor the outermost tile rows (their edges extend half a tile
    # past the limit), which dropped 2 * n tiles and raised a warning.
    coord_equal(ylim = height_range) +
    # A single outline drawn from the function arguments; annotate() adds one
    # rectangle instead of one per grid row.
    annotate(
      "rect",
      xmin = side_range[1], xmax = side_range[2],
      ymin = height_range[1], ymax = height_range[2],
      color = "white", fill = NA, linewidth = 1
    ) +
    labs(
      x = "Plate Side (ft)",
      y = "Plate Height (ft)",
      fill = "Predicted Probability",
      title = "Predicted Sweet-Spot Probability within Strike Zone"
    ) +
    theme_minimal(base_size = 14)
}

# Strike-zone heatmaps, one per fitted model and its matching data subset
plot_gam_heatmap_kzone(model = GAM_LL, data = LHP_LHH_df)
## Warning: Removed 400 rows containing missing values or values outside the scale range
## (`geom_tile()`).

plot_gam_heatmap_kzone(model = GAM_RR, data = RHP_RHH_df)
## Warning: Removed 400 rows containing missing values or values outside the scale range
## (`geom_tile()`).

plot_gam_heatmap_kzone(model = GAM_RL, data = RHP_LHH_df)
## Warning: Removed 400 rows containing missing values or values outside the scale range
## (`geom_tile()`).

plot_gam_heatmap_kzone(model = GAM_LR, data = LHP_RHH_df)
## Warning: Removed 400 rows containing missing values or values outside the scale range
## (`geom_tile()`).

library(gratia)
## 
## Attaching package: 'gratia'
## The following object is masked from 'package:stringr':
## 
##     boundary
# Individual smooths of GAM_LL for the four pitch-trait terms, with
# residuals requested on each panel
draw(GAM_LL, residuals = TRUE, select = "s(relspeed)")

draw(GAM_LL, residuals = TRUE, select = "s(spinrate)")

draw(GAM_LL, residuals = TRUE, select = "s(inducedvertbreak)")

draw(GAM_LL, residuals = TRUE, select = "s(horzbreak)")

# Pitch Type Analysis

# Each pitch is assigned to a cluster (0-4) based on its physical pitch traits.

# RHP/RHH subset with the pitch Cluster column attached
RHP_RHH_by_cluster <- read_csv(file = "RHP_RHH_GIDP_ByPitchCluster.csv")
## Rows: 144820 Columns: 43
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (15): gameid, pitcher, pitcherthrows, batter, batterside, pitchresult, ...
## dbl  (27): ab, pitchnum, inning, teambat, balls, strikes, outs, visscore, ho...
## time  (1): GameDate
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# RHP/LHH subset with the pitch Cluster column attached
RHP_LHH_by_cluster <- read_csv(file = "RHP_LHH_GIDP_ByPitchCluster.csv")
## Rows: 123442 Columns: 43
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (15): gameid, pitcher, pitcherthrows, batter, batterside, pitchresult, ...
## dbl  (27): ab, pitchnum, inning, teambat, balls, strikes, outs, visscore, ho...
## time  (1): GameDate
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# LHP/RHH subset with the pitch Cluster column attached
LHP_RHH_by_cluster <- read_csv(file = "LHP_RHH_GIDP_ByPitchCluster.csv")
## Rows: 72971 Columns: 43
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (15): gameid, pitcher, pitcherthrows, batter, batterside, pitchresult, ...
## dbl  (27): ab, pitchnum, inning, teambat, balls, strikes, outs, visscore, ho...
## time  (1): GameDate
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# LHP/LHH subset with the pitch Cluster column attached
LHP_LHH_by_cluster <- read_csv(file = "LHP_LHH_GIDP_ByPitchCluster.csv")
## Rows: 26371 Columns: 43
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (15): gameid, pitcher, pitcherthrows, batter, batterside, pitchresult, ...
## dbl  (27): ab, pitchnum, inning, teambat, balls, strikes, outs, visscore, ho...
## time  (1): GameDate
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

# Inference by pitch type and location

# Hold the pitch traits and release point constant while varying the pitch location.

# Per-cluster means of the pitch-trait and initial-position columns
# (one row per Cluster, NAs ignored)
cluster_summary_rr <- RHP_RHH_by_cluster %>%
  group_by(Cluster) %>%
  summarise(
    across(
      c(spinrate, relspeed, inducedvertbreak, horzbreak, initposx, initposz),
      function(values) mean(values, na.rm = TRUE)
    )
  )

# Strike-zone bounds: vertical limits and the two horizontal edges.
# Units are presumably feet, matching the "(ft)" plot axis labels - confirm.
k_zone_height_max <- 3.67
k_zone_height_min <- 1.52
# Alias for the original misspelled name, kept so existing references work.
k_zone_heightmin <- k_zone_height_min
sides <- c(-0.83, 0.83)

# Strike-zone heatmap of predicted probability for one pitch cluster.
#
# The pitch traits and initial position are fixed at the cluster's mean
# values while plate location varies over a 200 x 200 grid.
#
# Arguments:
#   model          fitted model passed to predict()
#   data           data frame with a Cluster column and the trait columns
#                  averaged below
#   pitch_cluster  value of Cluster to plot
#
# Returns a ggplot object. Stops with an error when `data` has no rows for
# `pitch_cluster`; previously that case produced NaN traits and a blank plot.
plot_heatmap_per_pitchtype <- function(model, data,
                                       pitch_cluster) {
  k_zone_height_max <- 3.67
  k_zone_height_min <- 1.52
  sides <- c(-0.83, 0.83)

  cluster_rows <- data %>%
    filter(Cluster == pitch_cluster)
  if (nrow(cluster_rows) == 0) {
    stop("No rows found for pitch cluster ", pitch_cluster, call. = FALSE)
  }

  # One-row table of trait means for the requested cluster
  cluster_means <- cluster_rows %>%
    summarise(
      across(
        c(spinrate, relspeed, inducedvertbreak, horzbreak, initposx, initposz),
        function(values) mean(values, na.rm = TRUE)
      )
    )

  grid <- expand.grid(
    platelocside = seq(sides[1], sides[2], length.out = 200),
    platelocheight = seq(k_zone_height_min, k_zone_height_max, length.out = 200)
  )

  # Each mean is a single value, recycled down the whole grid
  for (trait in names(cluster_means)) {
    grid[[trait]] <- cluster_means[[trait]]
  }
  grid$Cluster <- pitch_cluster

  # Predictions on the response scale
  grid$pred_prob <- predict(model, newdata = grid, type = "response")

  ggplot(grid, aes(platelocside, platelocheight, fill = pred_prob)) +
    geom_tile() +
    scale_fill_viridis_c(option = "C", direction = -1) +
    # annotate() draws a single outline; geom_rect(aes(...)) drew one
    # rectangle per grid row.
    annotate(
      "rect",
      xmin = sides[1], xmax = sides[2],
      ymin = k_zone_height_min, ymax = k_zone_height_max,
      color = "white", fill = NA, linewidth = 0.8
    ) +
    coord_equal() +
    labs(
      title = paste("Predicted Sweet-Spot Probability — Pitch Type", pitch_cluster),
      x = "Plate Side (ft)",
      y = "Plate Height (ft)",
      fill = "Predicted Prob."
    ) +
    theme_minimal(base_size = 14)
}

# Distinct cluster labels present in the RHP/RHH data
unique(RHP_RHH_by_cluster[["Cluster"]])
## [1] 1 0 3 4 2
# One heatmap per cluster label, using the RHP/RHH model
plot_heatmap_per_pitchtype(
  model = GAM_RR, data = RHP_RHH_by_cluster, pitch_cluster = 0
)

plot_heatmap_per_pitchtype(
  model = GAM_RR, data = RHP_RHH_by_cluster, pitch_cluster = 1
)

plot_heatmap_per_pitchtype(
  model = GAM_RR, data = RHP_RHH_by_cluster, pitch_cluster = 2
)

plot_heatmap_per_pitchtype(
  model = GAM_RR, data = RHP_RHH_by_cluster, pitch_cluster = 3
)

plot_heatmap_per_pitchtype(
  model = GAM_RR, data = RHP_RHH_by_cluster, pitch_cluster = 4
)

# Faceted strike-zone heatmaps of predicted probability, one panel per
# pitch cluster.
#
# For every cluster in `data`, the pitch traits and initial position are
# fixed at the cluster means while plate location varies over a 200 x 200
# grid.
#
# Arguments:
#   data   data frame with a Cluster column and the trait columns averaged
#          below
#   model  fitted model passed to predict()
#
# Returns a ggplot object faceted by Cluster.
#
# Prediction stays best-effort: if predict() fails for a cluster, a warning
# carrying the underlying error message is raised and that panel is filled
# with 0, so a failed panel is no longer silent.
make_heatmap_grid <- function(data, model) {
  # --- strike-zone limits ---
  k_zone_height_min <- 1.52
  k_zone_height_max <- 3.67
  sides <- c(-0.83, 0.83)

  # --- cluster means ---
  cluster_summary <- data %>%
    group_by(Cluster) %>%
    summarise(
      across(
        c(spinrate, relspeed, inducedvertbreak, horzbreak, initposx, initposz),
        function(values) mean(values, na.rm = TRUE)
      )
    )

  # Prediction grid for one row of cluster_summary
  build_cluster_grid <- function(row_index) {
    # Row indexing, not `Cluster == value`, so an NA cluster label still
    # selects exactly one row
    traits <- cluster_summary[row_index, ]

    grid <- expand.grid(
      platelocside   = seq(sides[1], sides[2], length.out = 200),
      platelocheight = seq(k_zone_height_min, k_zone_height_max,
                           length.out = 200)
    )

    # Attach the cluster label and trait means (single values, recycled)
    for (v in names(traits)) {
      grid[[v]] <- traits[[v]]
    }

    # Warnings from predict() are still suppressed as before; errors are
    # reported instead of being discarded
    pred <- tryCatch(
      suppressWarnings(
        as.numeric(predict(model, newdata = grid, type = "response"))
      ),
      error = function(e) {
        warning(
          "Prediction failed for cluster ", traits$Cluster, ": ",
          conditionMessage(e),
          call. = FALSE
        )
        rep(NA_real_, nrow(grid))
      }
    )

    # Replace NA/NaN/Inf with the mean of the finite predictions, or with 0
    # when there are none. The mean is taken over finite values only so the
    # replacement can never itself be non-finite.
    usable <- is.finite(pred)
    pred[!usable] <- if (any(usable)) mean(pred[usable]) else 0
    grid$pred_prob <- pred

    grid
  }

  all_grids <- dplyr::bind_rows(
    lapply(seq_len(nrow(cluster_summary)), build_cluster_grid)
  )

  # --- plot ---
  # NOTE(review): the fixed fill limits give all panels a shared colour
  # scale; with ggplot2's default out-of-bounds handling, predictions outside
  # [0.02, 0.5] are drawn in the NA colour - confirm that is intended.
  ggplot(all_grids, aes(platelocside, platelocheight, fill = pred_prob)) +
    geom_tile() +
    scale_fill_viridis_c(option = "C", direction = -1, limits = c(0.02, 0.5)) +
    facet_wrap(~ Cluster, ncol = 3) +
    coord_equal() +
    labs(
      title = "Predicted Sweet-Spot Probability by Pitch Cluster",
      x = "Plate Side (ft)", y = "Plate Height (ft)", fill = "Predicted Prob."
    ) +
    theme_minimal(base_size = 14)
}

# Faceted per-cluster heatmaps for the RHP/RHH model
make_heatmap_grid(data = RHP_RHH_by_cluster, model = GAM_RR)

# Per-cluster trait means for the RHP/RHH data. This is the same table
# computed earlier into cluster_summary_rr, so it is reused here rather than
# recomputed with a duplicated pipeline.
cluster_summary_rr
## # A tibble: 5 × 7
##   Cluster spinrate relspeed inducedvertbreak horzbreak initposx initposz
##     <dbl>    <dbl>    <dbl>            <dbl>     <dbl>    <dbl>    <dbl>
## 1       0    2393.     86.5             3.69     -4.38    -1.68     5.76
## 2       1    2311.     94.6            15.8       7.39    -1.48     5.79
## 3       2    1518.     85.8             3.84     11.0     -1.58     5.68
## 4       3    2137.     93.2             7.54     15.1     -1.71     5.56
## 5       4    2602.     81.3            -4.50    -11.9     -1.80     5.73